package com.gitee.dufafei.spark.sql

import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, Or, PredicateHelper}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{RuntimeConfig, SparkSession}

/**
 * Catalyst rule that rejects a `Filter` placed directly on top of a partitioned,
 * catalog-backed file relation when the filter either contains no predicate on the
 * partition columns or selects more partitions than the configured limit.
 *
 * The rule reads its settings from the runtime configuration of the supplied
 * `sparkSession`:
 *  - `spark.table.check.partition`: enables the check (treated as "true" when unset);
 *  - `spark.table.check.partition.num`: maximum number of partitions a filter may select
 *    (treated as `Int.MaxValue` when unset).
 *
 * @param sparkSession session whose runtime configuration controls the rule
 */
case class CheckPartitionTable(sparkSession: SparkSession) extends Rule[LogicalPlan] with PredicateHelper {

  // Configuration key: whether the partition check is enabled.
  val check_partition: String = "spark.table.check.partition"
  // Configuration key: upper bound on the number of partitions a filter may select.
  val check_num_partition: String = "spark.table.check.partition.num"
  val conf: RuntimeConfig = sparkSession.conf

  /**
   * Decides whether `filter` must be rejected.
   *
   * Only a child that is a non-streaming `LogicalRelation` over a `HadoopFsRelation` whose
   * location is a `CatalogFileIndex` is inspected; for any other child a warning is logged
   * and `false` is returned.
   *
   * @param filter       the filter node whose direct child is examined
   * @param numPartition maximum number of partitions the filter is allowed to select
   * @return `true` when the relation has partition columns and the filter condition either
   *         yields no partition predicate or selects more than `numPartition` partitions;
   *         `false` otherwise
   */
  def isPartitionTable(filter: Filter, numPartition: Int): Boolean =
    filter.child match {
      case logicalRelation @ LogicalRelation(
            fsRelation @ HadoopFsRelation(location: CatalogFileIndex, partitionSchema: StructType, _, _, _, _),
            _,
            catalogTable,
            false) =>
        // catalogTable is an Option; a relation without catalog metadata is treated as
        // "not partitioned" instead of failing on Option.get.
        if (!catalogTable.exists(_.partitionColumnNames.nonEmpty)) {
          false
        } else {
          // Resolve the partition schema to the attributes exposed by this relation.
          val resolver = fsRelation.sparkSession.sessionState.analyzer.resolver
          val partitionColumns = logicalRelation.resolve(partitionSchema, resolver)
          log.info(s"partitionColumns : $partitionColumns")
          // Keep only the parts of the condition that are expressed on partition columns.
          val partitionKeyFilters = splitPredicates(filter.condition, AttributeSet(partitionColumns))
          if (partitionKeyFilters.isEmpty) {
            // No predicate on the partition columns at all.
            true
          } else {
            log.info(s"partitionKeyFiltersExpression:$partitionKeyFilters")
            // Ask the file index which partitions survive the partition predicates.
            val partitions = location.filterPartitions(partitionKeyFilters).partitionSpec().partitions
            log.info(s"partitions : $partitions")
            partitions.size > numPartition
          }
        }
      case _ =>
        log.warn("未获取到表信息")
        false
    }

  /**
   * Recursively extracts from `condition` the predicates that refer only to partition
   * columns.
   *
   *  - `And`: the partition predicates of both sides are concatenated.
   *  - `Or`: kept (as a single `Or` of the and-ed sides) only when both sides yield
   *    partition predicates; otherwise nothing is returned for the disjunction.
   *  - anything else: kept when it references at least one attribute and all of its
   *    references belong to `partitionSet`.
   *
   * @param condition    the filter condition to decompose
   * @param partitionSet attributes of the partition columns
   * @return the partition-only predicates found in `condition` (possibly empty)
   */
  def splitPredicates(condition: Expression, partitionSet: AttributeSet): Seq[Expression] =
    condition match {
      case And(left, right) =>
        splitPredicates(left, partitionSet) ++ splitPredicates(right, partitionSet)
      case Or(left, right) =>
        val leftFilters = splitPredicates(left, partitionSet)
        val rightFilters = splitPredicates(right, partitionSet)
        if (leftFilters.nonEmpty && rightFilters.nonEmpty) {
          Or(leftFilters.reduceLeft(And), rightFilters.reduceLeft(And)) :: Nil
        } else {
          Nil
        }
      case other =>
        val references = other.references
        // An expression without any column reference (e.g. `1 = 1`) has an empty reference
        // set, which is trivially a subset of the partition columns; it must not count as a
        // partition predicate, otherwise it would let a full scan pass the check.
        if (references.nonEmpty && references.subsetOf(partitionSet)) other :: Nil else Nil
    }

  /**
   * Applies the check to every `Filter` node of `plan`.
   *
   * @param plan the logical plan to validate
   * @return `plan` itself when the check is disabled; otherwise the result of the
   *         transformation, which rewrites nothing and only throws on offending filters
   * @throws Exception when a filter on a partitioned relation is rejected by
   *                   [[isPartitionTable]]
   */
  def apply(plan: LogicalPlan): LogicalPlan = {
    if (!conf.get(check_partition, "true").toBoolean) {
      log.warn("Partition check is not enabled.")
      plan
    } else {
      // Parsed at most once per invocation, and only if a Filter node is actually reached.
      lazy val maxPartitions = conf.get(check_num_partition, Int.MaxValue.toString).toInt
      plan transform {
        case filter @ Filter(condition, child) if isPartitionTable(filter, maxPartitions) =>
          throw new Exception(
            s"${condition.sql} ${child.treeString} No partition information is added to the partition table")
      }
    }
  }
}
